import argparse
import numpy as np
import numpy.random as npr
import time
import os
import sys
import pickle
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from torch.optim.lr_scheduler import MultiStepLR
from torchvision import datasets, transforms
from networks import ResNet18BN
from utils import Cutout, in_out_split_noisy, in_out_split_avg_case

# Format time for printing purposes
def get_hms(seconds):
    """Split a duration given in seconds into hours, minutes and seconds.

    Returns a ``(hours, minutes, seconds)`` tuple computed with ``divmod``,
    so the components are ints for an int input and floats for a float input.
    """
    whole_minutes, secs = divmod(seconds, 60)
    hours, minutes = divmod(whole_minutes, 60)
    return hours, minutes, secs


# Train model for one epoch
#
# example_stats: dictionary containing statistics accumulated over every presentation of example
#
def train(args, model, device, trainset, model_optimizer, epoch, example_stats):
    """Train ``model`` for one epoch and record per-example statistics.

    The training set is visited once in a freshly shuffled order. For every
    presentation of an example, its loss, correctness (0/1) and
    misclassification margin are appended to ``example_stats`` under the
    example's index in the original dataset, looked up through the
    module-level ``train_indx``. The running training accuracy is appended to
    ``example_stats['train'][1]`` after every batch.

    Relies on the module-level ``criterion``, which must return one loss
    value per example (no reduction).

    Args:
        args: parsed command-line namespace; ``batch_size`` and ``epochs``
            are read.
        model: network to optimise; switched to training mode.
        device: torch device the batches are moved to.
        trainset: dataset exposing ``targets`` and ``__getitem__``; the first
            element of each item is used as the input tensor.
        model_optimizer: optimizer stepping the parameters of ``model``.
        epoch: index of the current epoch, used for progress output only.
        example_stats: dict updated in place with the collected statistics.
    """
    correct = 0
    total = 0

    model.train()

    num_examples = len(trainset.targets)
    batch_size = args.batch_size
    # Ceiling division: "n // batch_size + 1" over-counts by one whenever the
    # dataset size is an exact multiple of the batch size.
    num_batches = (num_examples + batch_size - 1) // batch_size

    # Convert the label list once per epoch rather than once per batch
    all_targets = np.array(trainset.targets)

    # Get permutation to shuffle trainset
    trainset_permutation_inds = npr.permutation(np.arange(num_examples))

    print('\n=> Training Epoch #%d' % (epoch))

    for batch_idx, batch_start_ind in enumerate(
            range(0, num_examples, batch_size)):

        # Get trainset indices for batch
        batch_inds = trainset_permutation_inds[batch_start_ind:
                                               batch_start_ind + batch_size]

        # Get batch inputs and targets, transform them appropriately
        inputs = torch.stack([trainset[ind][0] for ind in batch_inds])
        targets = torch.LongTensor(all_targets[batch_inds].tolist())

        # Map to available device
        inputs, targets = inputs.to(device), targets.to(device)

        # Forward propagation, compute loss, get predictions
        model_optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        _, predicted = torch.max(outputs.data, 1)
        acc = predicted == targets

        # Pull everything needed for the statistics off the device in one go
        # instead of issuing several .item() calls per example.
        logit_rows = outputs.detach().tolist()
        target_list = targets.tolist()
        loss_list = loss.detach().tolist()
        correct_flags = acc.tolist()

        # Update statistics and loss
        for j, index in enumerate(batch_inds):

            # Get index in original dataset (not sorted by forgetting)
            index_in_original_dataset = train_indx[index]

            # Compute misclassification margin: correct-class output minus
            # the highest output among the other classes. When the example
            # is classified correctly the maximum is the correct class
            # itself, so the runner-up is taken instead.
            row = logit_rows[j]
            ranked = sorted(row)
            if correct_flags[j]:
                output_highest_incorrect_class = ranked[-2]
            else:
                output_highest_incorrect_class = ranked[-1]
            margin = row[target_list[j]] - output_highest_incorrect_class

            # Add the statistics of the current training example to dictionary
            index_stats = example_stats.get(index_in_original_dataset,
                                            [[], [], []])
            index_stats[0].append(loss_list[j])
            index_stats[1].append(int(correct_flags[j]))
            index_stats[2].append(margin)
            example_stats[index_in_original_dataset] = index_stats

        # Reduce loss, backward propagate, update optimizer
        loss = loss.mean()
        total += targets.size(0)
        correct += acc.sum().item()
        loss.backward()
        model_optimizer.step()

        running_acc = 100. * correct / total

        sys.stdout.write('\r')
        sys.stdout.write(
            '| Epoch [%3d/%3d] Iter[%3d/%3d]\t\tLoss: %.4f Acc@1: %.3f%%' %
            (epoch, args.epochs, batch_idx + 1, num_batches, loss.item(),
             running_acc))
        sys.stdout.flush()

        # Add running training accuracy to dict (one entry per batch)
        index_stats = example_stats.get('train', [[], []])
        index_stats[1].append(running_acc)
        example_stats['train'] = index_stats


# Evaluate model predictions on heldout test data
#
# example_stats: dictionary containing statistics accumulated over every presentation of example
#
def test(epoch, model, device, example_stats):
    """Evaluate ``model`` on the held-out test set and record its accuracy.

    Reads the module-level ``test_dataset`` and ``criterion`` and updates the
    module-level ``best_acc`` when the accuracy of this evaluation exceeds it.
    The accuracy (in percent) is appended to ``example_stats['test'][1]``.

    Args:
        epoch: index of the current epoch, used for the log line only.
        model: network to evaluate; switched to evaluation mode.
        device: torch device the batches are moved to.
        example_stats: dict updated in place with the test accuracy.
    """
    global best_acc
    test_loss = 0.
    correct = 0
    total = 0
    test_batch_size = 32

    model.eval()

    num_examples = len(test_dataset.targets)
    # Convert the label list once rather than once per batch
    all_targets = np.array(test_dataset.targets)

    # No gradients are needed for evaluation; skipping graph construction
    # saves memory and time.
    with torch.no_grad():
        for batch_start_ind in range(0, num_examples, test_batch_size):
            batch_end_ind = min(num_examples,
                                batch_start_ind + test_batch_size)

            # Get batch inputs and targets
            inputs = torch.stack([
                test_dataset[ind][0]
                for ind in range(batch_start_ind, batch_end_ind)
            ])
            targets = torch.LongTensor(
                all_targets[batch_start_ind:batch_end_ind].tolist())

            # Map to available device
            inputs, targets = inputs.to(device), targets.to(device)

            # Forward propagation, compute loss, get predictions
            outputs = model(inputs)
            loss = criterion(outputs, targets).mean()
            _, predicted = torch.max(outputs.data, 1)

            # Weight each batch mean by its size so the final (possibly
            # smaller) batch does not skew the reported average.
            batch_count = targets.size(0)
            test_loss += loss.item() * batch_count
            total += batch_count
            correct += predicted.eq(targets.data).sum().item()

    # Add test accuracy to dict
    acc = 100. * correct / total
    index_stats = example_stats.get('test', [[], []])
    index_stats[1].append(acc)
    example_stats['test'] = index_stats

    # Report the mean loss over the whole test set, not just the last batch
    print("\n| Validation Epoch #%d\t\t\tLoss: %.4f Acc@1: %.2f%%" %
          (epoch, test_loss / total, acc))

    # Track the best accuracy seen so far (no checkpoint is written here)
    if acc > best_acc:
        best_acc = acc


# Command-line interface.
parser = argparse.ArgumentParser()

# Generic training options
parser.add_argument('--dataset', default='cifar10', help='dataset to use')
parser.add_argument('--model', default='ResNet18', help='model to use')
parser.add_argument(
    '--batch_size',
    type=int,
    default=128,
    help='input batch size for training (default: 128)')
parser.add_argument(
    '--epochs',
    type=int,
    default=200,
    help='number of epochs to train (default: 200)')
parser.add_argument(
    '--learning_rate', type=float, default=0.1, help='learning rate')
parser.add_argument(
    '--data_augmentation',
    action='store_true',
    default=False,
    help='augment data by flipping and cropping')
parser.add_argument(
    '--cutout', action='store_true', default=False, help='apply cutout')
parser.add_argument(
    '--n_holes',
    type=int,
    default=1,
    help='number of holes to cut out from image')
parser.add_argument(
    '--length', type=int, default=16, help='length of the holes')
# The flag turns CUDA *off*; the previous help text claimed the opposite.
parser.add_argument(
    '--no-cuda',
    action='store_true',
    default=False,
    help='disables CUDA training')
parser.add_argument(
    '--seed', type=int, default=1, help='random seed (default: 1)')
parser.add_argument(
    '--optimizer',
    default="sgd",
    help='optimizer to use, default is sgd. Can also use adam')
parser.add_argument(
    '--input_dir',
    default=None,
    help='directory where to read sorting file from')
parser.add_argument(
    '--output_dir', required=True, help='directory where to save results')

# Shadow-model / LiRA options
parser.add_argument('--data_path', type=str, required=True, help='path to the dataset')
parser.add_argument('--exp_id', type=int, required=True, help='shadow model id')
parser.add_argument('--num_shadow', type=int, required=True, help='total num of shadow models')
parser.add_argument('--num_canaries', type=int, default=512, help='num of noisy labels')
parser.add_argument('--lira_path', type=str, required=True, help='path to save LiRA files, e.g., in-out-split indices, canaries indices, and noise targets')
parser.add_argument('--avg_case', action='store_true', default=False, help='use average case in-out split')

# Enter all arguments that you want to be in the filename of the saved output
ordered_args = ['dataset', 'data_augmentation', 'cutout', 'seed']

# Parse arguments and setup name of output file with forgetting stats
args = parser.parse_args()

# Validate explicitly: an `assert` would be stripped under `python -O` and
# would not give the user a usage message.
if not 0 <= args.exp_id < args.num_shadow:
    parser.error(
        f'--exp_id must satisfy 0 <= exp_id < num_shadow '
        f'(got exp_id={args.exp_id}, num_shadow={args.num_shadow})')

# Every shadow model writes into its own sub-directory
args.output_dir = os.path.join(args.output_dir, f'exp_{args.exp_id}')
args_dict = vars(args)
print(args_dict)
save_fname = '__'.join(
    '{}_{}'.format(arg, args_dict[arg]) for arg in ordered_args)

# Set appropriate devices
args.cuda = not args.no_cuda and torch.cuda.is_available()
use_cuda = args.cuda
device = torch.device("cuda" if use_cuda else "cpu")
cudnn.benchmark = True  # Should make training go faster for large models

# Set random seed for initialization
torch.manual_seed(args.seed)
if args.cuda:
    torch.cuda.manual_seed(args.seed)
npr.seed(args.seed)

# Image preprocessing: per-channel statistics are given on the 0-255 scale
# and rescaled to the [0, 1] range produced by ToTensor.
channel_means = [125.3, 123.0, 113.9]
channel_stds = [63.0, 62.1, 66.7]
normalize = transforms.Normalize(
    mean=[value / 255.0 for value in channel_means],
    std=[value / 255.0 for value in channel_stds])

# Assemble the training pipeline step by step: optional crop/flip
# augmentation, tensor conversion + normalisation, optional cutout.
train_steps = []
if args.data_augmentation:
    train_steps += [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
    ]
train_steps += [transforms.ToTensor(), normalize]
if args.cutout:
    train_steps.append(Cutout(n_holes=args.n_holes, length=args.length))
train_transform = transforms.Compose(train_steps)

# Test images are only converted and normalised
test_transform = transforms.Compose([transforms.ToTensor(), normalize])

# Make sure the results directory exists before anything is written to it
os.makedirs(args.output_dir, exist_ok=True)

# Only CIFAR-10 is wired up; reject anything else before touching the disk
if args.dataset != 'cifar10':
    raise ValueError(f"Dataset {args.dataset} not supported")

num_classes = 10
num_channel = 3

# Train and test splits share the storage location and download behaviour
cifar_kwargs = dict(root=args.data_path, download=True)
train_dataset = datasets.CIFAR10(
    train=True, transform=train_transform, **cifar_kwargs)
test_dataset = datasets.CIFAR10(
    train=False, transform=test_transform, **cifar_kwargs)

def _save_or_verify(path, array, description):
    """Save ``array`` to ``path`` on first use, otherwise check it matches.

    All shadow-model runs sharing one ``lira_path`` must agree on these
    files, so a mismatch with an existing file is reported as an error
    instead of being silently overwritten.
    """
    if not os.path.exists(path):
        np.save(path, array)
    elif not np.array_equal(array, np.load(path)):
        raise RuntimeError(
            f'{description} differ from the ones already stored in {path}')


indices_path = os.path.join(args.lira_path, 'indices')
os.makedirs(indices_path, exist_ok=True)

if not args.avg_case:
    # Mislabeling setup: generate in/out indices, canary indices and noisy
    # labels
    noisy_targets, shadow_in_indices, canary_indices = in_out_split_noisy(
        clean_train_ys=train_dataset.targets,
        seed=0,
        num_shadow=args.num_shadow,
        num_canaries=args.num_canaries,
    )
    if len(noisy_targets) != len(train_dataset.targets):
        raise RuntimeError(
            f'in_out_split_noisy returned {len(noisy_targets)} targets for '
            f'{len(train_dataset.targets)} training examples')

    # Overwrite the training labels with the returned targets (presumably
    # identical to the clean labels except at the canary positions -- verify
    # against utils.in_out_split_noisy)
    for i, noisy_target in enumerate(noisy_targets):
        train_dataset.targets[i] = noisy_target

    _save_or_verify(
        os.path.join(args.lira_path, 'noisy_targets.npy'),
        np.array(noisy_targets), 'Noisy targets')
    _save_or_verify(
        os.path.join(args.lira_path, 'canary_indices.npy'),
        np.array(canary_indices), 'Canary indices')
else:
    # Average-case split: no label noise, only in/out membership
    shadow_in_indices = in_out_split_avg_case(
        dataset_size=len(train_dataset.targets),
        seed=0,
        num_shadow=args.num_shadow,
    )

# Indices of the examples this shadow model is trained on
selected_indices = np.array(shadow_in_indices[args.exp_id])
_save_or_verify(
    os.path.join(indices_path, f'indice_{args.exp_id}.npy'),
    selected_indices, f'In-split indices of shadow model {args.exp_id}')
    

# Positions (in the full training set) of the examples this run trains on;
# train() reads this module-level name to map batch positions back to them.
train_indx = selected_indices

# Restrict images and labels to the selected subset
train_dataset.data = train_dataset.data[train_indx]
subset_targets = np.asarray(train_dataset.targets)[train_indx]
train_dataset.targets = subset_targets.tolist()

num_train_examples = len(train_dataset.targets)
print('Training on ' + str(num_train_examples) + ' examples')

# Setup model. Fail fast on an unknown name instead of printing a message and
# crashing later with a NameError.
if args.model == 'ResNet18':
    model = ResNet18BN(channel=num_channel, num_classes=num_classes)
else:
    raise ValueError(f'Specified model {args.model} not supported')

# Place the model on the selected device so --no-cuda and CPU-only hosts work
model = model.to(device)

# Setup loss. reduction='none' keeps one loss value per example, which the
# per-example statistics collected during training rely on.
criterion = nn.CrossEntropyLoss(reduction='none').to(device)

# Setup optimizer (a learning-rate scheduler is only created for SGD)
if args.optimizer == 'adam':
    model_optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
elif args.optimizer == 'sgd':
    model_optimizer = torch.optim.SGD(
        model.parameters(),
        lr=args.learning_rate,
        momentum=0.9,
        nesterov=True,
        weight_decay=5e-4)
    scheduler = MultiStepLR(
        model_optimizer, milestones=[60, 120, 160], gamma=0.2)
else:
    raise ValueError(
        'Specified optimizer not recognized. Options are: adam and sgd')

# Statistics for every example presentation, filled in by train() and test()
example_stats = {}

# best_acc is updated by test() through its `global` declaration
best_acc = 0
elapsed_time = 0

# Common prefix of the per-run output files
fname = os.path.join(args.output_dir, save_fname)

for epoch in range(args.epochs):
    epoch_began = time.time()

    train(args, model, device, train_dataset, model_optimizer, epoch,
          example_stats)
    test(epoch, model, device, example_stats)

    elapsed_time += time.time() - epoch_began
    hours, minutes, seconds = get_hms(elapsed_time)
    print('| Elapsed time : %d:%02d:%02d' % (hours, minutes, seconds))

    # Advance the learning-rate schedule (only SGD has a scheduler).
    # NOTE(review): passing `epoch` to step() is reported as deprecated by
    # recent PyTorch releases; dropping it would likely shift the milestones
    # by one epoch, so the call is left as is -- confirm before changing.
    if args.optimizer == 'sgd':
        scheduler.step(epoch)

    # Persist the statistics after every epoch so a partial run is usable
    with open(fname + "__stats_dict.pkl", "wb") as stats_file:
        pickle.dump(example_stats, stats_file)

    # Log the best train and test accuracy so far
    with open(fname + "__best_acc.txt", "w") as acc_file:
        best_train = max(example_stats['train'][1])
        best_test = max(example_stats['test'][1])
        acc_file.write('train test \n' + str(best_train) + ' ' + str(best_test))